{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Python Machine Learning in Practice: Code Examples\n",
    "\n",
    "# Chapter 17: Support Vector Machines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
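    "## Classifying Images with a Support Vector Machine\n",
    "\n",
    "The Digits dataset consists of 8×8 images of handwritten Arabic numerals. It contains 1797 samples covering the ten digits 0-9. Sample images from the dataset are shown below:\n",
    "\n",
    "![Digits dataset](./imgs/digits 数据集.png)\n",
    "\n",
    "The classification proceeds as follows: load the dataset, randomly split it into an 80% training set and a 20% test set, train an SVM classifier with a Gaussian (RBF) kernel on the training set, evaluate it on the test set, and finally print the classification report and the confusion matrix."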
   ]
  },
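  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the Gaussian (RBF) kernel mentioned above concrete, the cell below is a minimal sketch that evaluates $K(x, z) = \\exp(-\\gamma \\lVert x - z \\rVert^2)$ by hand for two digit images and compares the result with scikit-learn's `rbf_kernel`. The helper name `gaussian_kernel` is hypothetical and introduced only for illustration; `gamma = 0.001` is chosen to match the value passed to `SVC` later in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "\n",
    "\n",
    "def gaussian_kernel(x, z, gamma):\n",
    "    # Hypothetical helper: K(x, z) = exp(-gamma * ||x - z||^2).\n",
    "    diff = x - z\n",
    "    return np.exp(-gamma * np.dot(diff, diff))\n",
    "\n",
    "\n",
    "# Flattened digit images, one 64-dimensional row per sample.\n",
    "X_demo = load_digits().data\n",
    "gamma = 0.001  # assumed to match the SVC setting used in this notebook\n",
    "\n",
    "# Kernel value between the first two samples, computed two ways.\n",
    "k_manual = gaussian_kernel(X_demo[0], X_demo[1], gamma)\n",
    "k_sklearn = rbf_kernel(X_demo[:1], X_demo[1:2], gamma=gamma)[0, 0]\n",
    "print(k_manual, k_sklearn)\n"
   ]
  },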
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWcAAAB0CAYAAABZjfMMAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACdtJREFUeJzt3W2MXGUZxvHrolWhAt0tolECLSWCCcRuWhKNQigCQmJw\nCwlVoiT4QleCRoLoVoTQwge6fDDlA8QGJVVqNF2lLQZjbDHbGDDGLmyNr0Ta8mpA0m4pQiGR2w/n\nNFnXhfN09szsM6f/X7LJzs4953nm3u51zsycp8cRIQBAXo6a6QkAAP4f4QwAGSKcASBDhDMAZIhw\nBoAMEc4AkKEjKpxt32r7/pmeRxPR2/aht+2Va387Es62z7H9iO1x2y/Z/q3tJZ0Yewq1nNhtu9f2\nJtuv2N5t+8o6ttvCPJrY2+ts/8H2Qdv31bHNFufRqN7afqft79veY3u/7cdsX1LH5FqcT6P6K0m2\n77f9z/I5/c32l1rd1uw6JvR2bB8n6ReSBiQNS3qnpHMlvd7usdvsHkkHJZ0oabGkh2yPRcRfOzWB\nBvf2OUm3S7pY0jEzMYGG9na2pKclnRsRz9j+lKSNts+KiKc7OZGG9leS7pB0TUQctH26pO22H4uI\nxw93Q504cj5dUkTExii8HhHbIuJPkmR7oe2Hyz3ni7Y32D7+0IPLo9Ibbe+0fcD2vbbfa/uXtl+2\n/Wvbc8va+bbftH2N7efKr2+81cRsf7Tcc++z/bjt81KekO05ki6XdHNEvBYRj0jaIumqafSpFY3r\nrYontDkiHpS0t/XWTFvjehsRr0bEbRHxTHn7IUm7Jc3E0Wrj+qviCf0lIg4e2pSKI/LTWmmQIqKt\nX5KOk/QvSeslXSKpZ9L9p0m6QMVe/QRJI5K+O+H+3ZIelfQeSe+X9IKkHZI+rGJv+7CkW8ra+ZLe\nlPRjSUdLOkvSi5I+Ud5/q6Qfld+fJOklSReXty8ob59Q3h6U9OBbPKc+Sa9M+tkNkra0u59N7+2k\n+d8u6b5O9vRI6W1Z+z5Jr0o6nf7W119Jd0v6dznmDklzWupRh34RZ0i6T8VLqjdUHGWe+Ba1/ZJG\nJ/0Srpxw+2eS7p5w+6uSHpj0S/jghPuHJN07xS/hW5J+OGnsX0m6KuH5nCPp+Uk/+7Kk38zAP/JG\n9XbSY2YsnI+A3s6WtFXSPfS3Lf21pI9JuknSrFb605EPBCPi7xHxxYg4RcVe6wOS1kpS+VLkJ7af\ntT0uaYOKveFEL0z4/rUpbh87cThJz064/VQ53mTzJS23vbf82ifp4yr2wlVekXT8pJ/NlXQg4bG1\namBvs9HU3tp2Od/XJX0t9XF1a2p/y+cWEfGopJMlXXs4jz2k46fSRcQTKl7KnFX+6A4Ve7UzI6JH\n0udV7HVaZRUNOeQUSc9PUfeMir3lvPKrNyKOi4g7E8Z4QtJs2xPfS1ok6c8tz7oGDeltlhrW2x+o\nCLrLI+I/rU+5Pg3r70Sz1eJ7zm0PZ9tn2L7B9knl7ZMlXSnpd2XJsSqORA+UNd+sYdhbbB9j+0xJ\nX5D00ylqNki61PYnbR9l+2jb59meam/6PyLiVUkPSLrN9hzb50i6VFJHz5VsYm8lyfYs20dLmqVi\nJ/gu27NqmHuyBvf2e5I+JOnTEfFGDXNuSRP7a/tE25+x/e7ysRdL+qykba1MthNHzgckfUTS720f\nUPEm/h8l3Vjev1rFp8XjKk6t+fmkx08+/zDlfMTtkv6h4j21OyPi4ckFEfGsivexblLxwcRT5ZyO\nkiTb37b90NuMcZ2kOSo+WNgg6SvRwdPoSk3t7c0qPqgalPS58vvvJMytTo3rre1TJK1Q8YH2C+VZ\nDi97Zs7Rb1x/yzlcq+Loe6+kOyV9PYqzYg6byzevG8H2fEm7JL0jIt6c6fk0Cb1tH3rbXt3a3yYu\n357O+1J4e/S2fehte3Vdf5sYzs15KZ
Afets+9La9uq6/jXpbAwCaoolHzgDQ9er8j49qOQQfHh6u\nrBkcHKysueiii5LGW7NmTWVNb29v0rYStPq+V0df3ixdurSyZnx8vLJm1apVlTXLli1LmFGyVvrb\n0d6OjIxU1qT0pK+vr5axDsOM/tsdGhqqrFm5cmVlzamnnpo03ujoaGVNu3OBI2cAyBDhDAAZIpwB\nIEOEMwBkiHAGgAwRzgCQIcIZADLU9gu8Hq6Uc5h3795dWbNv376k8ebNm1dZs3HjxsqaK664Imm8\nbtDT01NZs3379sqaus7p7QZjY2NJdeeff35lzdy5cytr9uzZkzReN0g5Pznlb3DdunWVNQMDA0lz\nSjnP+cILL0zaVqs4cgaADBHOAJAhwhkAMkQ4A0CGCGcAyBDhDAAZIpwBIEOEMwBkiHAGgAx1dIVg\nyqqblNV/Tz75ZGXNwoULk+aUcsWUlHl3wwrB1FVsdV1BI+VqHU2xefPmpLpFixZV1qSsmly9enXS\neN1gxYoVlTUpK4eXLFlSWZN6JZR2r/5LwZEzAGSIcAaADBHOAJAhwhkAMkQ4A0CGCGcAyBDhDAAZ\nIpwBIEMdXYSScumoxYsXV9akLjBJkXLiejdYu3ZtZc2qVauStrV///5pzqawdOnSWrbTDa6//vqk\nugULFtSyrf7+/qTxukHK3/OuXbsqa1IWsKUuLknJqt7e3qRttYojZwDIEOEMABkinAEgQ4QzAGSI\ncAaADBHOAJAhwhkAMkQ4A0CGsluEknJlkjrlcLJ5HVIWLlx99dVJ26rr+Y6Pj9eynZmW8jxSFgFJ\n6VdMqbJ+/fpattMtUhaq7N27t7ImdRFKSt22bdsqa6bzt8SRMwBkiHAGgAwRzgCQIcIZADJEOANA\nhghnAMgQ4QwAGSKcASBDHV2EknJC9ujoaC1jpSwukaQdO3ZU1ixfvny60zkijY2NVdb09fV1YCbT\nk3IFmbvuuqu28TZt2lRZ09PTU9t4TZGSLykLRyRpYGCgsmZoaKiyZs2aNUnjTYUjZwDIEOEMABki\nnAEgQ4QzAGSIcAaADBHOAJAhwhkAMkQ4A0CGCGcAyFBHVwimXGomZcXe8PBwLTWpBgcHa9sWuk/K\n5b1GRkaStrVz587Kmssuu6yypr+/v7Im9bJky5YtS6qbSStXrqysSbm0VOrK4a1bt1bWtHvlMEfO\nAJAhwhkAMkQ4A0CGCGcAyBDhDAAZIpwBIEOEMwBkiHAGgAxltwgl5dIvKYtCzj777KQ51XVZrG6Q\nemmjlAUOW7ZsqaxJWZiRulBiJqVcSivlklypdSmXxUrp/4IFCxJm1B2LUFIuQbVixYraxktZYLJu\n3braxpsKR84AkCHCGQAyRDgDQIYIZwDIEOEMABkinAEgQ4QzAGSIcAaADDkiZnoOAIBJOHIGgAwR\nzgCQIcIZADJEOANAhghnAMgQ4QwAGSKcASBDhDMAZIhwBoAMEc4AkCHCGQAyRDgDQIYIZwDIEOEM\nABkinAEgQ4QzAGSIcAaADBHOAJAhwhkAMkQ4A0CG/guuTD70IhJY5AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x22aa1378470>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Embed matplotlib figures directly in the notebook.\n",
    "%matplotlib inline\n",
    "\n",
    "# Import the required packages.\n",
    "\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import metrics\n",
    "from matplotlib import pyplot as plt\n",
    "import random\n",
    "\n",
    "# Load the sample dataset.\n",
    "\n",
    "digits = load_digits()\n",
    "\n",
    "# Get the images and their corresponding labels.\n",
    "\n",
    "images = digits.images\n",
    "labels = digits.target\n",
    "\n",
    "# Plot the first four sample images.\n",
    "\n",
    "for index, (image, label) in enumerate(zip(images[:4], labels[:4])):\n",
    "    plt.subplot(1, 4, index + 1)\n",
    "    plt.axis('off')\n",
    "    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n",
    "    plt.title('Sample: %i' % label)\n"
   ]
  },
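  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An SVM expects one feature vector per sample, so each 8×8 image has to be flattened into a vector of 64 pixel values before training. The cell below is a small sketch of that step; the names `demo_images` and `demo_vectors` are illustrative only and chosen so that they do not overwrite the notebook's own variables. It also checks that the result matches `digits.data`, which scikit-learn provides already flattened."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_digits\n",
    "\n",
    "demo_digits = load_digits()\n",
    "demo_images = demo_digits.images\n",
    "\n",
    "# Each sample is a 2-D array of pixel intensities.\n",
    "print(demo_images.shape)\n",
    "\n",
    "# Collapse the two image axes into one feature axis per sample.\n",
    "demo_vectors = demo_images.reshape((len(demo_images), -1))\n",
    "print(demo_vectors.shape)\n",
    "\n",
    "# The dataset's .data attribute holds the same flattened pixels.\n",
    "same_as_data = np.array_equal(demo_vectors, demo_digits.data)\n",
    "print(same_as_data)\n"
   ]
  },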
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classification report: SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
      "  decision_function_shape='ovr', degree=3, gamma=0.001, kernel='rbf',\n",
      "  max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
      "  tol=0.001, verbose=False):\n",
      "             precision    recall  f1-score   support\n",
      "\n",
      "          0       1.00      1.00      1.00        40\n",
      "          1       1.00      1.00      1.00        32\n",
      "          2       1.00      1.00      1.00        37\n",
      "          3       0.97      1.00      0.98        30\n",
      "          4       1.00      1.00      1.00        25\n",
      "          5       1.00      1.00      1.00        36\n",
      "          6       1.00      1.00      1.00        45\n",
      "          7       1.00      0.98      0.99        45\n",
      "          8       1.00      1.00      1.00        39\n",
      "          9       0.97      0.97      0.97        30\n",
      "\n",
      "avg / total       0.99      0.99      0.99       359\n",
      "\n",
      "\n",
      "Confusion matrix:\n",
      "[[40  0  0  0  0  0  0  0  0  0]\n",
      " [ 0 32  0  0  0  0  0  0  0  0]\n",
      " [ 0  0 37  0  0  0  0  0  0  0]\n",
      " [ 0  0  0 30  0  0  0  0  0  0]\n",
      " [ 0  0  0  0 25  0  0  0  0  0]\n",
      " [ 0  0  0  0  0 36  0  0  0  0]\n",
      " [ 0  0  0  0  0  0 45  0  0  0]\n",
      " [ 0  0  0  0  0  0  0 44  0  1]\n",
      " [ 0  0  0  0  0  0  0  0 39  0]\n",
      " [ 0  0  0  1  0  0  0  0  0 29]]\n"
     ]
    }
   ],
   "source": [
    "# Flatten each image into a vector.\n",
    "\n",
    "n_samples = len(images)\n",
    "image_vectors = images.reshape((n_samples, -1))\n",
    "\n",
    "# Split into training and test sets.\n",
    "\n",
    "sample_index = list(range(n_samples))\n",
    "test_size = int(n_samples * 0.2)\n",
    "random.shuffle(sample_index)\n",
    "train_index, test_index = sample_index[test_size:], sample_index[:test_size]\n",
    "X_train, Y_train = image_vectors[train_index], labels[train_index]\n",
    "X_test, Y_test = image_vectors[test_index], labels[test_index]\n",
    "\n",
    "# Build an SVC classifier with an RBF kernel, train it, and predict on the test set.\n",
    "\n",
    "classifier = SVC(kernel='rbf', C=1.0, gamma=0.001)\n",
    "classifier.fit(X_train, Y_train)\n",
    "prediction = classifier.predict(X_test)\n",
    "\n",
    "# Print the classification report and the confusion matrix.\n",
    "\n",
    "print(\"Classification report: %s:\\n%s\\n\" %\n",
    "      (classifier, metrics.classification_report(Y_test, prediction)))\n",
    "print(\"Confusion matrix:\\n%s\" % metrics.confusion_matrix(Y_test, prediction))\n"
   ]
  }
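  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The manual shuffle-and-slice split above can also be written with scikit-learn's `train_test_split`, and the test accuracy can be reported as a single number with `accuracy_score`. The cell below is a sketch of that alternative, not the method used in the cells above: the fixed `random_state=0` is an assumption added here so that the split is reproducible, and the variable names (`X_tr`, `X_te`, `model`, and so on) are illustrative. Because the split differs from the random one above, the exact scores may differ slightly as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_digits\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Load the flattened images and their labels directly.\n",
    "X_all, y_all = load_digits(return_X_y=True)\n",
    "\n",
    "# Hold out 20% of the samples for testing; random_state is an assumed seed.\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(\n",
    "    X_all, y_all, test_size=0.2, random_state=0)\n",
    "\n",
    "# Same classifier settings as in the cell above.\n",
    "model = SVC(kernel='rbf', C=1.0, gamma=0.001)\n",
    "model.fit(X_tr, y_tr)\n",
    "\n",
    "# Fraction of test samples whose predicted digit matches the true label.\n",
    "accuracy = accuracy_score(y_te, model.predict(X_te))\n",
    "print('Test accuracy: %.3f' % accuracy)\n"
   ]
  }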
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
